import json
import time
import hydra
import asyncio
from rich import print
from omegaconf import DictConfig, OmegaConf

from envs.multiagent_env import MultiAgentEnv
from envs.utils.config import get_env_args
from utils.rollout import rollout
from policies import construct_policy_map
from utils.metrics import Metrics

def eval_single_env(config, policies):
    """Evaluate ``policies`` on a single environment built from ``config``.

    Args:
        config: Configuration object. It is passed to ``MultiAgentEnv`` and
            its ``num_eval_episodes`` and ``logdir`` attributes are forwarded
            to ``rollout``.
        policies: Policy mapping handed to both ``Metrics`` and ``rollout``.

    Returns:
        The value produced by ``metrics.report()`` after the rollout finishes.
    """
    # Create environment
    env = MultiAgentEnv(config)
    try:
        # Build metrics reporter for policies
        metrics = Metrics(policies)

        # Start evaluation; the rollout's return value is not needed here
        # because the report is read from `metrics` afterwards.
        rollout(
            env,
            policies,
            num_episodes=config.num_eval_episodes,
            metrics=metrics,
            logdir=config.logdir,
        )

        # Report metrics
        return metrics.report()
    finally:
        # Always release the environment, even when the rollout or the
        # metric reporting raises.
        env.close()

@hydra.main(version_base=None, config_path="../configs", config_name="config")
def eval(config):
    """Evaluate checkpointed policies on one or more environment scenarios.

    Prints the resolved configuration, builds the policy map from
    ``config.exp_config``, loads checkpoint ``config.checkpoint`` into every
    policy, and then runs ``eval_single_env`` once per entry of
    ``config.env_config.eval_env_configs``. When that key is absent (or
    ``None``), the scenario already set in
    ``config.env_config.scenario_config`` is evaluated instead.

    The collected reports are keyed by the last ``/``-separated component of
    the scenario config string, printed, and written as JSON to
    ``<config.logdir>/metric_reports.json``.

    NOTE: the function name shadows the builtin ``eval``; it is kept because
    it is this script's entry point.

    Args:
        config: Configuration object supplied by Hydra.
    """
    print(OmegaConf.to_yaml(config))

    # Get all evaluation environment configurations
    eval_env_configs = config.env_config.get('eval_env_configs', None)

    # Generate policy mappings
    policies = construct_policy_map(config.exp_config)
    # Load checkpoint for each policy
    for policy in policies.values():
        policy.load(ckpt_num=config.checkpoint)

    # Main evaluation loop for all env configurations
    metric_reports = {}
    start_time = time.time()
    if eval_env_configs is not None:
        for eval_env_config in eval_env_configs:
            # The scenario is swapped in place on the shared config before
            # each run; the same policies are reused across scenarios.
            config.env_config.scenario_config = eval_env_config
            metric_report = eval_single_env(config, policies)
            metric_reports[eval_env_config.split("/")[-1]] = metric_report
    else:
        # If no specific eval config is provided, evaluate the default environment
        metric_report = eval_single_env(config, policies)
        metric_reports[config.env_config.scenario_config.split("/")[-1]] = metric_report
    print(metric_reports)

    # Use a context manager so the report file is flushed and closed
    # deterministically, even if serialization fails part-way through.
    with open(config.logdir + "/metric_reports.json", "w") as report_file:
        json.dump(metric_reports, report_file)
    print(f"Total evaluation time: {time.time() - start_time} seconds")

# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    eval()
